# Introduction ======

# This script contains the code used to make Figure 6 (EEG-psychiatric symptom mediation).

# Set up ----

## Load packages and metadata ---------

#Load packages
pacman::p_load(tidyverse,patchwork,ghibli,eegUtils,ggdag)

#Get our default settings and functions
source("./eLife Submission Scripts/Analysis-Common-Utilities.R")

# Figure 6 =========

## Fig 6A: DAG =======

# Node layout for the mediation DAG: genotype (G) bottom-left, EEG measure (E)
# at the apex, psychiatric symptoms (P) bottom-right.
dag_coords <- tribble(
  ~name, ~x, ~y,
  "G",    0,  0,
  "E",    1,  1,
  "P",    2,  0
)


# Mediation DAG: genotype (G) affects psychiatric symptoms (P) directly and
# also via the EEG measure (E). Node positions come from `dag_coords`.
p_dag <-
  dagify(P ~ G,
         P ~ E,
         E ~ G,
         coords = dag_coords) |>
  dag_label(labels = c("P" = "Psychiatric Symptoms",
                       "G" = "Genotype",
                       "E" = "EEG Measure")) |>
  ggplot(aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_point(color = "grey30",
                 alpha = 1/2, size = 16, show.legend = FALSE) +
  geom_dag_text(color = "black") +
  geom_dag_edges(edge_width = 0.75) +
  # Path labels: G -> P runs along the bottom, G -> E -> P over the top
  annotate(geom = "text", x = 1, y = -0.1, label = "direct effect") +
  annotate(geom = "text", x = 1, y = 1.3, label = "mediated effect") +
  # Hide axes entirely; the expansion keeps the large node points in frame
  scale_x_continuous(NULL, breaks = NULL, expand = c(.1, .1)) +
  scale_y_continuous(NULL, breaks = NULL, expand = c(.1, .1)) +
  theme_dag()


ggsave("./Figures/mediation_dag.pdf", plot = p_dag,
       width = 6, height = 6, units = "cm")


## Fig 6B: Mediation Topoplots =======

### Load the mediation models =========

# Fitted mediation models; Analysis-fitMediationModels.R contains the code
# that generates them.
# NOTE(review): this path spells the folder "elife" while the sourced
# utilities path spells it "eLife"; on a case-sensitive filesystem only one
# spelling can resolve - confirm the actual directory name.
p_med_60 <- readr::read_rds(
  file = "./elife Submission Data/sleep_study_mediation_plot.rds"
)


### Prepare the model data for plotting --------

# Combine the mediation models and extract the summary data needed for plotting

eeg_set     = c("cons","spin_amp","so_amp","itpc_overlap_mag_z")
outcome_set = c("sleepprobsall","adhdsymsall","anyanxsymsall","asqtotalsymsall","pesall","fsiqall")


#We need two datasets

# - the mediated effect/proportion mediated data for plotting the topoplot surface
# - the significant clusters for highlighting

# First get the proportion-mediated statistic for every outcome x mediator x
# electrode, with electrode coordinates from `topo` attached (`topo` is not
# defined in this script - presumably it comes from the sourced utilities).
p_med_plot <-
  p_med_60 |>
  mutate(outcome  = factor(outcome,  levels = outcome_set),
         mediator = factor(mediator, levels = eeg_set)) |>
  # Drop the top-level model data and expand the nested `clus` results,
  # which carry a `Term` column
  select(-data) |>
  unnest(clus) |>
  filter(Term == "Proportions Mediated") |>
  # Expand that term's per-electrode statistics
  select(outcome, mediator, data) |>
  unnest(data) |>
  select(outcome, mediator, electrode = g_v, statistic) |>
  left_join(topo, by = "electrode")

# Now let's get the set of all the cluster-corrected significant electrodes.
# An electrode is kept when it lies in a p < 0.05 cluster for BOTH the
# "Average Mediated Effect" and the "Total Effect" of the same
# outcome x mediator pair.
p_clus_plot <-
  p_med_60 |>
  select(-data) |>
  unnest(clus) |>
  select(-data) |>
  unnest(clus) |>
  filter(p.value < 0.05) |>
  filter(Term %in% c("Average Mediated Effect", "Total Effect")) |>
  select(outcome, mediator, Term, electrodes) |>
  # One row per outcome x mediator; each term's cell is a list of its
  # significant clusters' electrodes (NULL when the term had none)
  pivot_wider(names_from = Term, values_from = electrodes, values_fn = list) |>
  filter(!map_lgl(`Average Mediated Effect`, is.null)) |>
  mutate(me       = map(`Average Mediated Effect`, unlist),
         te       = map(`Total Effect`, unlist),
         sig_clus = map2(me, te, intersect)) |>
  # Keep pairs whose mediated and total effect clusters actually overlap
  filter(lengths(sig_clus) > 0) |>
  select(outcome, mediator, sig_clus) |>
  unnest(sig_clus) |>
  rename(electrode = sig_clus) |>
  left_join(topo, by = "electrode") |>
  # The topoplot's global aes() maps z/fill to `statistic`, and the
  # geom_point() layers drawn from this data inherit it, so the column must
  # exist; its value is irrelevant because those layers fix the fill.
  mutate(statistic = 1) |>
  mutate(outcome  = factor(outcome,  levels = outcome_set),
         mediator = factor(mediator, levels = eeg_set))


### Make topoplots =============================================================

# Grid of topoplots of |proportion mediated| (rows = EEG mediator,
# columns = outcome). Electrodes from `p_clus_plot` are overlaid as white
# dots with a black rim (a larger black point under a smaller white one).
p_med <-
  p_med_plot |>
  mutate(statistic = abs(statistic)) |>
  ggplot(aes(x = x, y = y, z = statistic, fill = statistic, label = electrode)) +
  geom_topo(grid_res = 200,
            colour = "white",
            size = 0.1,
            interp_limit = "head",
            chan_markers = "point",
            chan_size = 0.25,
            head_size = 0.5,
            method = "gam", breaks = 10) +
  # Layer order matters: the black rim must be drawn before the white centre
  geom_point(data = p_clus_plot, aes(x = x, y = y),
             colour = "black", fill = "black", size = 1.2) +
  geom_point(data = p_clus_plot, aes(x = x, y = y),
             colour = "white", fill = "white", size = 1) +
  facet_grid(mediator ~ outcome) +
  coord_equal() +
  # Values above 0.4 are squished to the top colour rather than dropped
  scale_fill_viridis_c(option = "magma",
                       limits = c(0, 0.4),
                       oob = scales::squish) +
  theme_void() +
  theme(panel.grid   = element_blank(),
        axis.text    = element_blank(),
        axis.title   = element_blank(),
        axis.text.y  = element_blank(),
        axis.text.x  = element_blank(),
        strip.text.y = element_text(colour = "black", size = 6, angle = 0),
        strip.text.x = element_text(colour = "black", size = 6, angle = 0),
        legend.title = element_text(colour = "black", size = 6),
        legend.text  = element_text(colour = "black", size = 6))

# Save the plots
ggsave(filename = "./Figures/eeg_mediation_topoplots.pdf",
       plot = p_med, width = 18, height = 28, units = "cm")


# Table 7 ======

# We also make a table of the mean (SD) absolute proportion mediated, pooled
# over the significant electrodes of each outcome x mediator pair.
# The join keys are explicit: an implicit natural join would also match on
# every shared `topo` coordinate column and print a "Joining with" message.
left_join(
  p_clus_plot |> select(outcome, mediator, electrode),
  p_med_plot  |> select(outcome, mediator, electrode, statistic),
  by = c("outcome", "mediator", "electrode")
) |>
  group_by(outcome, mediator) |>
  summarise(prop_mediated_mu = statistic |> abs() |> mean(),
            prop_mediated_sd = statistic |> abs() |> sd(),
            .groups = "drop") |>
  # Fixed two-decimal formatting keeps trailing zeros (e.g. "0.30"), so all
  # entries show the same precision; a single-electrode SD prints as "NA".
  transmute(Outcome  = outcome,
            Mediator = mediator,
            `Proportion Mediated` = sprintf("%.2f (%.2f)",
                                            prop_mediated_mu,
                                            prop_mediated_sd)) |>
  knitr::kable(format = "html", booktabs = TRUE) |>
  kableExtra::kable_styling(font_size = 11)
